{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/01-query-gen.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/01-query-gen.ipynb)\n",
    "\n",
    "We should already have the CORD-19 *document_parses/pdf_json* directory available locally, because we will read the text data from these files. First, we define a generator that yields one unique passage at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "paths = [str(path) for path in Path('document_parses/pdf_json').glob('*.json')]\n",
    "\n",
    "def get_text():\n",
    "    passage_dupes = set()\n",
    "    for path in paths:\n",
    "        with open(path, 'r', encoding='utf-8') as fp:\n",
    "            doc = json.load(fp)\n",
    "        # extract the passages of text from each document\n",
    "        body_text = [line['text'] for line in doc['body_text']]\n",
    "        # loop through and yield one passage at a time\n",
    "        for passage in body_text:\n",
    "            # skip passages we have already seen\n",
    "            if passage not in passage_dupes:\n",
    "                passage_dupes.add(passage)\n",
    "                yield passage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test the generator..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Until recently, seven types of coronaviruses had been reported to cause infections in humans [1] . Coronaviruses can use animal hosts and then can evolve to infect humans. This process is thought to explain the emergence of SARS-CoV (severe respiratory distress syndrome coronavirus) in 2003, MERS-CoV (Middle Eastern Respiratory Syndrome) in 2012, and SARS-CoV-2 in 2019. SARS-CoV-2 has 82% genome sequence similarity to SARS-CoV and 50% genome sequence homology to MERS-CoV. COVID-19 symptoms range from mild (fever, cough, or dyspnea) to moderate (respiratory failure requiring oxygen support) and can progress to ARDS (acute respiratory distress syndrome) and multiorgan failure. In one of the earlier studies, 80% of cases were mild, but the mortality rate ranged from 1.86% to 9.86% [2] . Higher mortality rates were reported in countries like Italy, possibly secondary to resource depletion in an overwhelmed health care system. Gastrointestinal symptoms such as diarrhea have been reported in approximately 2-10% of patients [3] , with a lower rate in China (3.7%) as compared to Singapore (17%) [4] . Liver injury has been reported in 60% of patients with SARS-CoV [5] and has also been reported in patients infected with MERS-CoV [6] . Studies also suggest that SARS-CoV-2 can affect the liver [7] [8] [9] . In a recent study published in Shanghai, 75 (50.7%) out of 148 patients were found to have elevated liver enzymes with SARS-CoV-2 [7] . In another study published in the Lancet in February 2020 by Huang et al., an increase in aspartate aminotransferase (AST) was observed in 62% in intensive care unit (ICU) patients compared to 25% in non-ICU patients, indicating that more severe disease correlates with worsening of liver enzymes [10] . Several other studies showed that liver injury in the form of an increase in AST/alanine aminotransferase (ALT) levels with a mild increase in bilirubin ranging from 14.8% to 53% [11] . 
In patients who died of SARS-CoV-2, liver injury was reported as high as 58.06% [12] . The highest levels recorded included an ALT of 7,590 U/L and an AST of 1,445 U/L [13] . Generally speaking, transaminase elevations are mild in patients with COVID-19. Here, we report a case of acute liver failure in an elderly patient with COVID-19 infection who did not have a history of preexisting liver disease.\n",
      "An 80-year-old male with a medical history of diabetes, hypertension, dyslipidemia, asthma, coronary artery disease with bypass graft, atrial fibrillation on warfarin, and heart failure with preserved ejection fraction with an automatic implantable cardiac defibrillator and pacemaker presented to the emergency department (ED) in March 2020 with intermittent fever, productive cough, and shortness of breath (SOB) for four to five days. He initially started noticing fever that was partially relieved by acetaminophen five days prior to presentation (maximum temperature of 102°F at home). This was associated with SOB on exertion, which progressed to SOB at rest and a productive cough. He denied any recent travel, contact with sick person, herbal medications use, or a recent change in home medications. His home medications included oral warfarin daily, oral metoprolol tartrate two times daily, oral metformin ER daily, oral aspirin daily, oral atorvastatin, and budesonide-formoterol inhaler twice daily. The review of systems was otherwise negative. The patient did not have a history of smoking, alcohol consumption, illicit drugs, or high-risk sexual behavior. Vitals at the time of admission showed a temperature of 101.6°F, heart rate of 80 beats/minute, blood pressure 140/70 of mm Hg, respiratory rate of 20 breaths/minute, and oxygen saturation of 98% on room air. Physical examination was positive for bilateral wheeze and rhonchi in all lung fields, 1+ pedal edema bilateral. His chest was without spider angiomas and abdomen with no hepatosplenomegaly, and he had no shifting dullness, with normoactive bowel sounds and no palmar erythema. On neurological examination, he was alert, oriented to time, place/person, followed commands, and had no focal deficits. Laboratory examination results are shown in Table 1 The patient had normal liver enzymes at presentation but had elevated transaminases on day 4. 
The examination at that time was negative for asterixis or encephalopathy. Atorvastatin was stopped, and the recommendation was made to start N acetylcysteine (NAC), and workup for acute and chronic liver disease was ordered. His respiratory status continued to deteriorate, requiring increased oxygen support. His radiologic findings worsened with enlarging infiltrates on a chest X-ray on day 4, as shown in Figures 1-3 . The patient then developed cytokine release syndrome (CRS) (elevated interleukin [IL]-6 and IL-10 as mentioned in Table 5 ), and he expired on day 9. \n",
      "COVID-19 is a pandemic illness that primarily affects the respiratory system with a wide spectrum of disease presentation that ranges from mild disease (fever, cough) to severe (ARDS, multiorgan failure). The gastrointestinal/hepatic systems are the next most commonly affected, with symptoms such as nausea, vomiting, diarrhea, and an increase in liver enzymes. Currently, studies on the exact pathophysiology of liver injury in these patients are limited, but it is believed either to be a direct effect of the virus or immune-mediated inflammatory response, such as CRS, hypoxemia, and failure of innate immune regulation, or to be drug-induced.\n"
     ]
    }
   ],
   "source": [
    "passages = get_text()\n",
    "\n",
    "for i, passage in enumerate(passages):\n",
    "    print(passage)\n",
    "    if i == 2: break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our passage generator, we need to create queries to pair with these passages. We do this using a T5 query generation model called `'doc2query/msmarco-t5-base-v1'`, which was trained on a pre-COVID dataset from 2018. The model still produces sensible COVID-related queries by copying words from the input text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "\n",
    "model_name = 'doc2query/msmarco-t5-base-v1'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test it on an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# take the next passage from the generator\n",
    "passage = next(passages)\n",
    "\n",
    "# tokenize the passage\n",
    "inputs = tokenizer(passage, return_tensors='pt')\n",
    "# generate three queries\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs['input_ids'].cuda(),\n",
    "    attention_mask=inputs['attention_mask'].cuda(),\n",
    "    max_length=64,\n",
    "    do_sample=True,\n",
    "    top_p=0.95,\n",
    "    num_return_sequences=3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Paragraph:\n",
      "1) It is postulated that both SARS-CoV-2 and SARS-CoV bind to angiotensin-converting enzyme 2 (ACE2) receptors to enter the target cell [14] where the virus replication begins and starts to infect cells of the upper respiratory tract. Based on the scRNA-seq data, Chai et al. [15] found that ACE 2 receptors also found in the hepatobiliary system (high in bile duct cells, cholangiocytes, when compared to liver cells). Cholangiocytes play a critical role in liver regeneration and immune responses [16] . Thus, the authors concluded that potential damage of cholangiocytes by 2019-nCoV might lead to profound consequences in the liver rather than the direct effect of the virus on hepatocytes.\n",
      "\n",
      "Generated Queries:\n",
      "1: where is sars infection found\n",
      "2: sars bacteria affects what organ\n",
      "3: what are the ace receptors that are associated with sars virus\n"
     ]
    }
   ],
   "source": [
    "print(\"Paragraph:\")\n",
    "print(passage)\n",
    "\n",
    "print(\"\\nGenerated Queries:\")\n",
    "for i in range(len(outputs)):\n",
    "    query = tokenizer.decode(outputs[i], skip_special_tokens=True)\n",
    "    print(f'{i + 1}: {query}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These mostly look like sensible queries that the passage answers. We will now apply this at a larger scale, creating a large number of *(query, passage)* pairs, which we will save to a TSV file at *data/pairs.tsv*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.makedirs('data', exist_ok=True)  # create the output directory if needed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will restrict the number of pairs we create, as CORD-19 is a *massive* dataset. Adjust the `target` threshold as you prefer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b55c2fa40c54bcbbb2ad16e05878fd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm  # this is our progress bar\n",
    "\n",
    "target = 200_000\n",
    "batch_size = 256\n",
    "num_queries = 3  # number of queries to generate for each passage\n",
    "count = 0\n",
    "lines = []\n",
    "passage_batch = []\n",
    "\n",
    "# reinitialize passage generator\n",
    "passages = get_text()\n",
    "\n",
    "with tqdm(total=target) as progress:\n",
    "    for passage in passages:\n",
    "        if count >= target: break\n",
    "        # remove tab + newline characters if present\n",
    "        passage_batch.append(passage.replace('\\t', ' ').replace('\\n', ' '))\n",
    "        \n",
    "        # we encode in batches\n",
    "        if len(passage_batch) == batch_size:\n",
    "            # tokenize the passage\n",
    "            inputs = tokenizer(\n",
    "                passage_batch,\n",
    "                truncation=True,\n",
    "                padding=True,\n",
    "                max_length=256,\n",
    "                return_tensors='pt'\n",
    "            )\n",
    "\n",
    "            # generate three queries per doc/passage\n",
    "            outputs = model.generate(\n",
    "                input_ids=inputs['input_ids'].cuda(),\n",
    "                attention_mask=inputs['attention_mask'].cuda(),\n",
    "                max_length=64,\n",
    "                do_sample=True,\n",
    "                top_p=0.95,\n",
    "                num_return_sequences=num_queries\n",
    "            )\n",
    "\n",
    "            # decode query to human readable text\n",
    "            decoded_output = tokenizer.batch_decode(\n",
    "                outputs,\n",
    "                skip_special_tokens=True\n",
    "            )\n",
    "\n",
    "            # loop through to pair query and passages\n",
    "            for i, query in enumerate(decoded_output):\n",
    "                query = query.replace('\\t', ' ').replace('\\n', ' ')  # remove newline + tabs\n",
    "                passage_idx = i // num_queries  # each passage yields num_queries consecutive queries\n",
    "                lines.append(query+'\\t'+passage_batch[passage_idx])\n",
    "                count += 1\n",
    "            \n",
    "            passage_batch = []\n",
    "            progress.update(len(decoded_output))\n",
    "\n",
    "# write (Q, P+) pairs to file\n",
    "with open('./data/pairs.tsv', 'w', encoding='utf-8') as fp:\n",
    "    fp.write('\\n'.join(lines))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now generated queries for our passages and can move on to the next step of GPL, called *negative mining*."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "conda-root-py",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
